'''
http://www.opensource.org/licenses/mit-license.php

Copyright 2007-2011 David Alan Cridland
Copyright 2011 Lance Stout
Copyright 2012 Tyler L Hobbs

Permission is hereby granted, free of charge, to any person obtaining a copy of this
software and associated documentation files (the "Software"), to deal in the Software
without restriction, including without limitation the rights to use, copy, modify, merge,
publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons
to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or
substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
'''
# These tests were adapted from the unit tests in the pure-sasl repo, i.e. https://github.com/thobbs/pure-sasl/tree/master/tests/unit,
# and refactored to cover the wrapper functions in sasl_compat.py, with added coverage for functions exclusive to sasl_compat.py.

import unittest
import base64
import hashlib
import hmac
import kerberos
from mock import patch
import six
import struct
from puresasl import SASLProtocolException, QOP
from puresasl.client import SASLError
from pyhive.sasl_compat import PureSASLClient, error_catcher


class TestPureSASLClient(unittest.TestCase):
    """Test cases for initialization of SASL client using PureSASLClient class.

    Covers mechanism selection through ``start()`` and the error state
    exposed through ``getError()`` and ``error_catcher``.
    """

    # Message expected from getError() whenever start() cannot select a mechanism.
    NO_MECHANISM_ERROR = 'None of the mechanisms listed meet all required properties'

    def setUp(self):
        """Create a client for 'localhost' without choosing a mechanism."""
        self.sasl_kwargs = {}
        self.sasl = PureSASLClient('localhost', **self.sasl_kwargs)

    def test_start_no_mechanism(self):
        """Test starting SASL authentication with no mechanism."""
        success, mechanism, response = self.sasl.start(mechanism=None)
        self.assertFalse(success)
        self.assertIsNone(mechanism)
        self.assertIsNone(response)
        self.assertEqual(self.sasl.getError(), self.NO_MECHANISM_ERROR)

    def test_start_wrong_mechanism(self):
        """Test starting SASL authentication with a single unsupported mechanism."""
        success, mechanism, response = self.sasl.start(mechanism='WRONG')
        self.assertFalse(success)
        self.assertEqual(mechanism, 'WRONG')
        self.assertIsNone(response)
        self.assertEqual(self.sasl.getError(), self.NO_MECHANISM_ERROR)

    def test_start_list_of_invalid_mechanisms(self):
        """Test starting SASL authentication with a list of unsupported mechanisms."""
        success, _, response = self.sasl.start(['invalid1', 'invalid2'])
        # The failure must be visible in the return value, not only in getError().
        self.assertFalse(success)
        self.assertIsNone(response)
        self.assertEqual(self.sasl.getError(), self.NO_MECHANISM_ERROR)

    def test_start_list_of_valid_mechanisms(self):
        """Test starting SASL authentication with a list of supported mechanisms."""
        self.sasl.start(['PLAIN', 'DIGEST-MD5', 'CRAM-MD5'])
        # Validate right mechanism is chosen based on score.
        self.assertEqual(self.sasl._chosen_mech.name, 'DIGEST-MD5')

    def test_error_catcher_no_error(self):
        """Test the error_catcher with no error."""
        with error_catcher(self.sasl):
            result, _, _ = self.sasl.start(mechanism='ANONYMOUS')

        self.assertIsNone(self.sasl.getError())
        self.assertTrue(result)

    def test_error_catcher_with_error(self):
        """Test the error_catcher with an error."""
        with error_catcher(self.sasl):
            result, _, _ = self.sasl.start(mechanism='WRONG')

        self.assertFalse(result)
        self.assertEqual(self.sasl.getError(), self.NO_MECHANISM_ERROR)

"""Assuming client initialization went well and a mechanism is chosen, below are the test cases for the different mechanisms."""

class _BaseMechanismTests(unittest.TestCase):
    """Shared checks inherited by every per-mechanism test case.

    Subclasses override ``mechanism`` and, where credentials or other
    properties are needed, ``sasl_kwargs``. The test methods accept ``*args``
    so that mock objects injected by ``patch`` decorators on a subclass are
    tolerated.
    """

    mechanism = 'ANONYMOUS'
    sasl_kwargs = {}

    def setUp(self):
        """Build a client for ``mechanism`` and keep a handle on its chosen mechanism object."""
        client = PureSASLClient('localhost', mechanism=self.mechanism, **self.sasl_kwargs)
        self.sasl = client
        self.mechanism_class = client._chosen_mech

    def test_init_basic(self, *args):
        """The chosen mechanism refers back to the client that created it."""
        client = PureSASLClient('localhost', mechanism=self.mechanism, **self.sasl_kwargs)
        self.assertIs(client._chosen_mech.sasl, client)

    def test_step_basic(self, *args):
        """A step with an arbitrary challenge succeeds and returns bytes."""
        outcome = self.sasl.step(six.b('string'))
        succeeded, payload = outcome
        self.assertTrue(succeeded)
        self.assertIsInstance(payload, six.binary_type)

    def test_decode_encode(self, *args):
        """encode and decode both report failure with an empty error message."""
        for operation in (self.sasl.encode, self.sasl.decode):
            self.assertEqual(operation('msg'), (False, None))
            self.assertEqual(self.sasl.getError(), '')


class AnonymousMechanismTest(_BaseMechanismTests):
    """Test case for the Anonymous SASL mechanism.

    Defines no tests of its own; it runs the checks inherited from
    ``_BaseMechanismTests`` with ``mechanism`` set to ``'ANONYMOUS'``.
    """

    mechanism = 'ANONYMOUS'


class PlainTextMechanismTest(_BaseMechanismTests):
    """Exercises the PLAIN mechanism, including authorization identity handling."""

    mechanism = 'PLAIN'
    username = 'user'
    password = 'pass'
    sasl_kwargs = {'username': username, 'password': password}

    def _assert_step_yields(self, client, challenge, expected):
        """Step ``client`` with ``challenge`` and require success with exactly ``expected`` bytes."""
        succeeded, payload = client.step(challenge)
        self.assertTrue(succeeded)
        self.assertEqual(payload, expected)
        self.assertIsInstance(payload, six.binary_type)

    def test_step(self):
        """The response is identical for every kind of challenge tried here."""
        expected = six.b(f'\x00{self.username}\x00{self.password}')
        for challenge in (None, '', b'asdf', u"\U0001F44D"):
            self._assert_step_yields(self.sasl, challenge, expected)

    def test_step_with_authorization_id_or_identity(self):
        """The response is prefixed by identity, or by authorization_id when both are given."""
        challenge = u"\U0001F44D"
        identity = 'user2'
        auth_id = 'user3'

        # An identity on its own ends up as the first field of the response.
        identity_kwargs = dict(self.sasl_kwargs, identity=identity)
        client = PureSASLClient('localhost', mechanism=self.mechanism, **identity_kwargs)
        self._assert_step_yields(
            client, challenge, six.b(f'{identity}\x00{self.username}\x00{self.password}'))
        self.assertTrue(client.complete)

        # When both are supplied, authorization_id takes priority over identity.
        both_kwargs = dict(identity_kwargs, authorization_id=auth_id)
        client = PureSASLClient('localhost', mechanism=self.mechanism, **both_kwargs)
        self._assert_step_yields(
            client, challenge, six.b(f'{auth_id}\x00{self.username}\x00{self.password}'))
        self.assertTrue(client.complete)

    def test_decode_encode(self):
        """decode and encode both succeed and return the message unchanged."""
        payload = 'msg'
        for transform in (self.sasl.decode, self.sasl.encode):
            self.assertEqual(transform(payload), (True, payload))


class ExternalMechanismTest(_BaseMechanismTests):
    """Exercises the EXTERNAL mechanism."""

    mechanism = 'EXTERNAL'

    def test_step(self):
        """A step without a challenge succeeds with an empty bytes response."""
        outcome = self.sasl.step()
        self.assertEqual(outcome, (True, b''))

    def test_decode_encode(self):
        """decode and encode both succeed and return the message unchanged."""
        payload = 'msg'
        for transform in (self.sasl.decode, self.sasl.encode):
            self.assertEqual(transform(payload), (True, payload))


@patch('puresasl.mechanisms.kerberos.authGSSClientStep')
@patch('puresasl.mechanisms.kerberos.authGSSClientResponse', return_value=base64.b64encode(six.b('some\x00 response')))
class GSSAPIMechanismTest(_BaseMechanismTests):
    """Test case for the GSSAPI SASL mechanism.

    The two class-level ``patch`` decorators replace ``authGSSClientStep`` and
    ``authGSSClientResponse`` in ``puresasl.mechanisms.kerberos`` for every
    test method; the latter returns the base64 encoding of
    ``b'some\\x00 response'``. Each test method receives the created mocks as
    extra positional arguments, with decorators closest to the function first:
    method-level mocks, then ``authGSSClientResponse``, then
    ``authGSSClientStep``. This is why the inherited tests accept ``*args``.
    """

    mechanism = 'GSSAPI'
    service = 'GSSAPI'
    sasl_kwargs = {'service': service}

    @patch('puresasl.mechanisms.kerberos.authGSSClientWrap')
    @patch('puresasl.mechanisms.kerberos.authGSSClientUnwrap')
    def test_decode_encode(self, _inner1, _inner2, authGSSClientResponse, *args):
        """Check decode/encode under each QOP.

        ``_inner1`` and ``_inner2`` are the mocks for ``authGSSClientUnwrap``
        and ``authGSSClientWrap`` respectively; they are only needed so the
        real functions are not called.
        """
        # bypassing step setup by setting qop directly
        self.mechanism_class.qop = QOP.AUTH
        msg = b'msg'
        # With plain authentication the message is returned unchanged.
        self.assertEqual(self.sasl.decode(msg), (True, msg))
        self.assertEqual(self.sasl.encode(msg), (True, msg))

        # Test behavior with the other QOPs (data integrity and confidentiality) for Kerberos authentication.
        for qop in (QOP.AUTH_INT, QOP.AUTH_CONF):
            self.mechanism_class.qop = qop
            # authGSSClientResponseConf returning 1: both directions yield the
            # base64-decoded value of the mocked authGSSClientResponse.
            with patch('puresasl.mechanisms.kerberos.authGSSClientResponseConf', return_value=1):
                self.assertEqual(self.sasl.decode(msg), (True, base64.b64decode(authGSSClientResponse.return_value)))
                self.assertEqual(self.sasl.encode(msg), (True, base64.b64decode(authGSSClientResponse.return_value)))
            if qop == QOP.AUTH_CONF:
                # authGSSClientResponseConf returning 0 while confidentiality
                # was requested: encode must fail and record the error.
                with patch('puresasl.mechanisms.kerberos.authGSSClientResponseConf', return_value=0):
                    self.assertEqual(self.sasl.encode(msg), (False, None))
                    self.assertEqual(self.sasl.getError(), 'Error: confidentiality requested, but not honored by the server.')

    def test_step_no_user(self, authGSSClientResponse, *args):
        """Check step() before a user is known, then once the user name becomes available."""
        msg = six.b('whatever')

        # no user: the step returns the decoded mocked response ...
        self.assertEqual(self.sasl.step(msg), (True, base64.b64decode(authGSSClientResponse.return_value)))
        # ... and an empty mocked response yields empty bytes.
        with patch('puresasl.mechanisms.kerberos.authGSSClientResponse', return_value=''):
            self.assertEqual(self.sasl.step(msg), (True, six.b('')))

        username = 'username'
        # with user; this has to be last because it sets mechanism.user
        with patch('puresasl.mechanisms.kerberos.authGSSClientStep', return_value=kerberos.AUTH_GSS_COMPLETE):
            with patch('puresasl.mechanisms.kerberos.authGSSClientUserName', return_value=six.b(username)):
                self.assertEqual(self.sasl.step(msg), (True, six.b('')))
                self.assertEqual(self.mechanism_class.user, six.b(username))

    @patch('puresasl.mechanisms.kerberos.authGSSClientUnwrap')
    def test_step_qop(self, *args):
        """Check the QOP / max-buffer negotiation step for every entry in QOP.bit_map."""
        # Put the mechanism directly into the state the negotiation step is tested from.
        self.mechanism_class._have_negotiated_details = True
        self.mechanism_class.user = 'user'
        msg = six.b('msg')
        # With the class-level mocked response the step is rejected.
        # NOTE(review): presumably because that response is not a 4-byte
        # QOP/size word -- confirm against puresasl's GSSAPI mechanism.
        self.assertEqual(self.sasl.step(msg), (False, None))
        self.assertEqual(self.sasl.getError(), 'Bad response from server')

        max_len = 100
        self.assertLess(max_len, self.sasl.max_buffer)
        for i, qop in QOP.bit_map.items():
            # Network-order 32-bit word: the bit_map key in the top byte,
            # max_len in the low bits.
            qop_size = struct.pack('!i', i << 24 | max_len)
            response = base64.b64encode(qop_size)
            with patch('puresasl.mechanisms.kerberos.authGSSClientResponse', return_value=response):
                with patch('puresasl.mechanisms.kerberos.authGSSClientWrap') as authGSSClientWrap:
                    self.mechanism_class.complete = False
                    self.assertEqual(self.sasl.step(msg), (True, qop_size))
                    self.assertTrue(self.mechanism_class.complete)
                    self.assertEqual(self.mechanism_class.qop, qop)
                    self.assertEqual(self.mechanism_class.max_buffer, max_len)

                    # Note: this rebinds the method's *args parameter to the
                    # positional arguments of the authGSSClientWrap call.
                    args = authGSSClientWrap.call_args[0]
                    out_data = args[1]
                    # The wrapped payload is the QOP/size word followed by the user name.
                    out = base64.b64decode(out_data)
                    self.assertEqual(out[:4], qop_size)
                    self.assertEqual(out[4:], six.b(self.mechanism_class.user))


class CramMD5MechanismTest(_BaseMechanismTests):
    """Test case for the CRAM-MD5 SASL mechanism."""

    mechanism = 'CRAM-MD5'
    username = 'user'
    password = 'pass'
    sasl_kwargs = {'username': username, 'password': password}

    def test_step(self):
        """Check step() without a challenge, then with one.

        Without a challenge the step succeeds with no response. With a
        challenge the response must be bytes containing the user name and the
        hex HMAC-MD5 of the challenge keyed with the password.
        """
        success, response = self.sasl.step(None)
        self.assertTrue(success)
        self.assertIsNone(response)

        challenge = six.b('msg')
        # Named expected_mac rather than `hash` to avoid shadowing the builtin.
        expected_mac = hmac.new(six.b(self.password), challenge, hashlib.md5)
        success, response = self.sasl.step(challenge)
        self.assertTrue(success)
        # Check the type first so a None response fails as an assertion
        # instead of a TypeError inside assertIn.
        self.assertIsInstance(response, six.binary_type)
        self.assertIn(six.b(self.username), response)
        self.assertIn(six.b(expected_mac.hexdigest()), response)
        self.assertTrue(self.sasl.complete)

    def test_decode_encode(self):
        """decode and encode both succeed and return the message unchanged."""
        msg = 'msg'
        self.assertEqual(self.sasl.decode(msg), (True, msg))
        self.assertEqual(self.sasl.encode(msg), (True, msg))


class DigestMD5MechanismTest(_BaseMechanismTests):
    """Test case for the DIGEST-MD5 SASL mechanism."""

    mechanism = 'DIGEST-MD5'
    username = 'user'
    password = 'pass'
    sasl_kwargs = {'username': username, 'password': password}

    def test_decode_encode(self):
        """decode and encode both succeed and return the message unchanged."""
        msg = 'msg'
        self.assertEqual(self.sasl.decode(msg), (True, msg))
        self.assertEqual(self.sasl.encode(msg), (True, msg))

    def test_step_basic(self, *args):
        """Disable the inherited generic step test for this mechanism.

        NOTE(review): presumably the base test's b'string' challenge is not a
        usable DIGEST-MD5 challenge; test_step covers a realistic one instead.
        """
        pass

    def test_step(self):
        """Test a SASL step with dummy challenge for DIGEST-MD5 mechanism."""
        test_challenge = (
            b'nonce="rmD6R8aMYVWH+/ih9HGBr3xNGAR6o2DUxpKlgDz6gUQ=",r'
            b'ealm="example.org",qop="auth,auth-int,auth-conf",cipher="rc4-40,rc'
            b'4-56,rc4,des,3des",maxbuf=65536,charset=utf-8,algorithm=md5-sess'
        )
        result, response = self.sasl.step(test_challenge)
        self.assertTrue(result)
        self.assertIsNotNone(response)

    def test_step_server_answer(self):
        """Test a SASL step with a proper server answer for DIGEST-MD5 mechanism."""
        sasl_kwargs = {'username': "chris", 'password': "secret"}
        sasl = PureSASLClient('elwood.innosoft.com',
                              service="imap",
                              mechanism=self.mechanism,
                              mutual_auth=True,
                              **sasl_kwargs)
        test_challenge = (
            b'utf-8,username="chris",realm="elwood.innosoft.com",'
            b'nonce="OA6MG9tEQGm2hh",nc=00000001,cnonce="OA6MHXh6VqTrRk",'
            b'digest-uri="imap/elwood.innosoft.com",'
            b'response=d388dad90d4bbd760a152321f2143af7,qop=auth'
        )
        sasl.step(test_challenge)
        # Pin the client nonce to the value used in the challenge above so the
        # rspauth value below refers to the same exchange.
        sasl._chosen_mech.cnonce = b"OA6MHXh6VqTrRk"

        server_response = (
            b'rspauth=ea40f60335c427b5527b84dbabcdfffd'
        )
        sasl.step(server_response)
        # Assert that step chooses the only QOP offered in the challenge.
        # This must inspect the client stepped above, not the untouched
        # fixture client created in setUp.
        self.assertEqual(sasl.qop, QOP.AUTH)
